import mysql.connector
import random
import nltk
from nltk import word_tokenize, pos_tag
from nltk.corpus import wordnet
nltk.download('averaged_perceptron_tagger')
import heapq

def dbConnection():
    """Open a MySQL connection and create a cursor on it.

    Autocommit is enabled on the connection after the cursor has been
    created.

    Returns:
        A ``(connection, cursor)`` tuple.
    """
    # NOTE(review): host and credentials are hard-coded in source; consider
    # loading them from configuration kept outside version control.
    settings = {
        "host": "csmysql.cs.cf.ac.uk",
        "user": "d1101011",
        "passwd": "sunnyandsunny",
        "database": "d1101011",
    }
    connection = mysql.connector.connect(**settings)

    cursor = connection.cursor()
    connection.autocommit = True

    return connection, cursor


# close database connection
def finish(mydb, mycursor):
    """Close the cursor first and then the connection it belongs to.

    Args:
        mydb: Open database connection (anything with a ``close()`` method).
        mycursor: Cursor created on *mydb*.
    """
    for handle in (mycursor, mydb):
        handle.close()

def baseCateg(label):
    """Flag five random 'base' sentences of one label with ``categ5 = '1'``.

    Selects the ``sentId`` of every row in ``safeguardSentencesReal`` whose
    ``experiment`` is ``'base'`` and whose ``overallLabel`` equals *label*,
    draws 5 of them at random, and sets ``experiment = 'base'`` and
    ``categ5 = '1'`` on the drawn rows.

    Args:
        label: Value compared with the ``overallLabel`` column, e.g. ``'[0]'``.

    Raises:
        ValueError: If fewer than 5 rows match; ``random.sample`` raises
            before any row is updated.
    """
    # An earlier variant of the SELECT filtered on "experiment is null"
    # instead of "experiment = 'base'".
    #
    # %s placeholders let the driver quote and escape the values, so neither
    # the label nor the ids are spliced into the SQL text.
    select_sql = ("SELECT sentId from safeguardSentencesReal "
                  "where experiment = 'base' "
                  "and overallLabel = %s;")
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET experiment = 'base', categ5= '1' WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql, (label,))
        sent_ids = [row[0] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    sample = random.sample(sent_ids, 5)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in sample:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)

def randomCateg(l):
    """Flag five random 'random'-experiment sentences with ``categ5 = '1'``.

    Selects the ``sentId`` of every row in ``safeguardSentencesReal`` whose
    ``experiment`` is ``'random'``, whose ``categ10`` is ``'1'`` and whose
    ``overallLabel`` equals *l*, draws 5 of them at random, and sets
    ``experiment = 'random'`` and ``categ5 = '1'`` on the drawn rows.

    Args:
        l: Value compared with the ``overallLabel`` column, e.g. ``'[0]'``.

    Raises:
        ValueError: If fewer than 5 rows match; ``random.sample`` raises
            before any row is updated.
    """
    # An earlier variant of the SELECT filtered on "experiment is null".
    #
    # %s placeholders let the driver quote and escape the values, so neither
    # the label nor the ids are spliced into the SQL text.
    select_sql = ("SELECT sentId from safeguardSentencesReal "
                  "where experiment = 'random' and categ10 = '1' "
                  "and overallLabel = %s;")
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET experiment = 'random', categ5= '1' WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql, (l,))
        sent_ids = [row[0] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    sample = random.sample(sent_ids, 5)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in sample:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)

def getTags(document):
    """Return the distinct nouns of *document*, lower-cased, in first-seen order.

    The text is tokenized with ``nltk.word_tokenize`` and tagged with
    ``nltk.pos_tag``; tokens tagged NN, NNS, NNP or NNPS are kept.

    Args:
        document: Text to analyse, as a single string.

    Returns:
        list of lower-cased noun tokens without duplicates.
    """
    noun_tags = {'NN', 'NNS', 'NNP', 'NNPS'}
    tagged = nltk.pos_tag(nltk.word_tokenize(document))

    # dict.fromkeys drops repeats while keeping first-occurrence order, which
    # avoids rescanning the result list for every token.
    return list(dict.fromkeys(
        word.lower() for word, tag in tagged if tag in noun_tags
    ))

def selectMaxNouns(alldata, n=5):
    """Return the ids of the documents with the most nouns found by ``getTags``.

    Args:
        alldata: Sequence of rows where ``row[0]`` is the document id and
            ``row[1]`` is its text.
        n: How many ids to return; defaults to 5.

    Returns:
        list of at most *n* ids, ordered from the highest noun count to the
        lowest. If an id appears in several rows, the count of its last row
        is the one used.
    """
    nounsCount = {row[0]: len(getTags(row[1])) for row in alldata}
    return heapq.nlargest(n, nounsCount, key=nounsCount.get)

def nounsCateg(l):
    """Flag the noun-richest 'nouns'-experiment sentences with ``categ5 = '1'``.

    Selects ``sentId`` and ``sentence`` of every row in
    ``safeguardSentencesReal`` whose ``experiment`` is ``'nouns'``, whose
    ``categ10`` is ``'1'`` and whose ``overallLabel`` equals *l*, asks
    ``selectMaxNouns`` for the ids to keep, and sets ``experiment = 'nouns'``
    and ``categ5 = '1'`` on those rows.

    Args:
        l: Value compared with the ``overallLabel`` column, e.g. ``'[0]'``.
    """
    # %s placeholders let the driver quote and escape the values, so neither
    # the label nor the ids are spliced into the SQL text.
    select_sql = ("SELECT sentId,sentence from safeguardSentencesReal "
                  "where experiment = 'nouns' and categ10='1' "
                  "and overallLabel = %s;")
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET experiment = 'nouns', categ5 = '1' WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql, (l,))
        rows = [[row[0], row[1]] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    # The tagging runs with no connection open; a new one is opened for the
    # updates afterwards.
    nounssamp = selectMaxNouns(rows)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in nounssamp:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)



def validCateg(l):
    """Flag five random 'valid'-experiment sentences with ``categ5 = '1'``.

    Selects the ``sentId`` of every row in ``safeguardSentencesReal`` whose
    ``experiment`` is ``'valid'``, whose ``categ10`` is ``'1'`` and whose
    ``overallLabel`` equals *l*, draws 5 of them at random, and sets
    ``categ5 = '1'`` on the drawn rows (``experiment`` is left untouched).

    Args:
        l: Value compared with the ``overallLabel`` column, e.g. ``'[0]'``.

    Raises:
        ValueError: If fewer than 5 rows match; ``random.sample`` raises
            before any row is updated.
    """
    # Earlier variants of the SELECT filtered on "experiment is null" and on
    # "experiment = 'valid'" without the categ10 condition.
    #
    # %s placeholders let the driver quote and escape the values, so neither
    # the label nor the ids are spliced into the SQL text.
    select_sql = ("SELECT sentId from safeguardSentencesReal "
                  "where experiment = 'valid' and categ10 = '1' "
                  "and overallLabel = %s;")
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET categ5= '1' WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql, (l,))
        sent_ids = [row[0] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    sample = random.sample(sent_ids, 5)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in sample:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)

def subclassCateg10():
    """Assign five random unassigned [4]/[4.1] sentences to the 'subclass' experiment.

    Selects the ``sentId`` of every row in ``safeguardSentencesReal`` whose
    ``experiment`` is NULL, whose ``overallLabel`` is ``'[4]'`` and whose
    ``secondLevel`` is ``'[4.1]'``, draws 5 of them at random, prints the
    drawn ids, and sets ``experiment = 'subclass'`` and ``categ10 = '1'`` on
    them.

    Raises:
        ValueError: If fewer than 5 rows match; ``random.sample`` raises
            before any row is updated.
    """
    select_sql = ("select sentId from safeguardSentencesReal "
                  "where experiment is null "
                  "and overallLabel = '[4]' and secondLevel = '[4.1]';")
    # The %s placeholder lets the driver quote and escape the id instead of
    # splicing it into the SQL text.
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET experiment = 'subclass', categ10= '1' "
                  "WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql)
        sent_ids = [row[0] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    sample = random.sample(sent_ids, 5)
    print(sample)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in sample:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)

def subclassCateg5():
    """Flag two random 'subclass' [4]/[4.1] sentences with ``categ5 = '1'``.

    Selects the ``sentId`` of every row in ``safeguardSentencesReal`` whose
    ``experiment`` is ``'subclass'``, whose ``overallLabel`` is ``'[4]'`` and
    whose ``secondLevel`` is ``'[4.1]'``, draws 2 of them at random, and sets
    ``categ5 = '1'`` on the drawn rows.

    Raises:
        ValueError: If fewer than 2 rows match; ``random.sample`` raises
            before any row is updated.
    """
    select_sql = ("select sentId from safeguardSentencesReal "
                  "where experiment = 'subclass' "
                  "and overallLabel = '[4]' and secondLevel = '[4.1]';")
    # The %s placeholder lets the driver quote and escape the id instead of
    # splicing it into the SQL text.
    update_sql = ("UPDATE safeguardSentencesReal "
                  "SET categ5= '1' WHERE sentId = %s;")

    mydb, mycursor = dbConnection()
    try:
        mycursor.execute(select_sql)
        sent_ids = [row[0] for row in mycursor.fetchall()]
    finally:
        # Release the connection even when the query fails.
        finish(mydb, mycursor)

    sample = random.sample(sent_ids, 2)

    mydb, mycursor = dbConnection()
    try:
        for sent_id in sample:
            mycursor.execute(update_sql, (sent_id,))
    finally:
        finish(mydb, mycursor)

def main():
    """Run the selection step that is currently enabled.

    Only ``subclassCateg5`` is executed; the other steps are kept below as
    comments for reference.
    """
    # Disabled per-label steps:
    #   labels = ['[0]','[1]','[2]','[3]','[4]']
    #   for l in labels:
    #       #baseCateg(l)
    #       #validCateg(l)
    #       #randomCateg(l)
    #       nounsCateg(l)

    #subclassCateg10()
    subclassCateg5()

# Called unconditionally (there is no __main__ guard), so the selection runs
# whenever this module is executed or imported.
main()